{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose no CUDA devices to any library imported after this point.\n",
    "os.environ.update({'CUDA_VISIBLE_DEVICES': ''})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Load the topic vocabulary: only the keys of the JSON object are kept.\n",
    "with open('/home/husein/alxlnet/topics.json') as fopen:\n",
    "    topics = {*json.load(fopen).keys()}\n",
    "\n",
    "# List copy of the same topics (a set cannot be indexed).\n",
    "list_topics = [*topics]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/xlnet/model_utils.py:295: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import xlnet\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "import model_utils\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sentencepiece as spm\n",
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "# SentencePiece processor that `tokenize_fn` hands to `encode_ids`.\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "# `Load` is the last expression of the cell, so its return value is displayed.\n",
    "sp_model.Load('sp10m.cased.v9.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.text.function import transformer_textcleaning as cleaning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "def tokenize_fn(text):\n",
    "    \"\"\"Return the ids `encode_ids` gives for `text` after `preprocess_text` (no lower-casing).\"\"\"\n",
    "    normalized = preprocess_text(text, lower = False)\n",
    "    return encode_ids(sp_model, normalized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Segment ids attached to the positions of an encoded sequence.\n",
    "SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = range(5)\n",
    "\n",
    "# Special symbol -> id table; each symbol's id is its position in the tuple below.\n",
    "special_symbols = dict(\n",
    "    (symbol, index)\n",
    "    for index, symbol in enumerate(\n",
    "        (\n",
    "            \"<unk>\",\n",
    "            \"<s>\",\n",
    "            \"</s>\",\n",
    "            \"<cls>\",\n",
    "            \"<sep>\",\n",
    "            \"<pad>\",\n",
    "            \"<mask>\",\n",
    "            \"<eod>\",\n",
    "            \"<eop>\",\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "# Shortcuts for the special ids that are looked up by name.\n",
    "UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (\n",
    "    special_symbols[symbol]\n",
    "    for symbol in (\"<unk>\", \"<cls>\", \"<sep>\", \"<mask>\", \"<eod>\")\n",
    ")\n",
    "\n",
    "def F(left_train):\n",
    "    \"\"\"Encode one text as `(token ids, segment ids, input mask)` lists.\n",
    "\n",
    "    The ids from `tokenize_fn` are followed by `SEP_ID` and then `CLS_ID`.\n",
    "    Every position is labelled `SEG_ID_A` except the final one, which is\n",
    "    `SEG_ID_CLS`. The mask is all zeros and as long as the token list.\n",
    "    \"\"\"\n",
    "    tokens_a = tokenize_fn(left_train)\n",
    "    n_pieces = len(tokens_a)\n",
    "    tokens_a.extend((SEP_ID, CLS_ID))\n",
    "    # One extra SEG_ID_A covers the separator; the closing symbol gets its own id.\n",
    "    segment_id = [SEG_ID_A] * (n_pieces + 1) + [SEG_ID_CLS]\n",
    "    input_mask = [0 for _ in tokens_a]\n",
    "    return tokens_a, segment_id, input_mask\n",
    "\n",
    "def XY(data):\n",
    "    \"\"\"Turn one `(text, phrases)` record into two encoded inputs and a 0/1 label.\n",
    "\n",
    "    When at least one entry of `data[1]` is a known topic and the random draw\n",
    "    exceeds 0.2, a random entry of `data[1]` is paired with the text (label 1).\n",
    "    Otherwise a random topic that is absent from `data[1]` is used (label 0).\n",
    "\n",
    "    Returns `(F(cleaning(data[0])), F(phrase), label)`.\n",
    "    \"\"\"\n",
    "    phrases = data[1]\n",
    "    own_phrases = set(phrases)\n",
    "\n",
    "    # `random.random()` is only consulted when the record shares a phrase with `topics`.\n",
    "    if own_phrases & topics and random.random() > 0.2:\n",
    "        t, label = random.choice(phrases), 1\n",
    "    else:\n",
    "        t, label = random.choice(list(topics - own_phrases)), 0\n",
    "\n",
    "    return F(cleaning(data[0])), F(t), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the key-phrase test set; the whole file is parsed as one JSON document.\n",
    "with open('/home/husein/alxlnet/testset-keyphrase.json') as fopen:\n",
    "    data = json.loads(fopen.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.sequence import pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/xlnet/xlnet.py:63: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class Parameter:\n",
    "    \"\"\"Plain attribute bag for the optimisation hyper-parameters.\n",
    "\n",
    "    Each named argument is stored under an attribute of the same name.\n",
    "    Any additional keyword argument is accepted and discarded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        decay_method,\n",
    "        warmup_steps,\n",
    "        weight_decay,\n",
    "        adam_epsilon,\n",
    "        num_core_per_host,\n",
    "        lr_layer_decay_rate,\n",
    "        use_tpu,\n",
    "        learning_rate,\n",
    "        train_steps,\n",
    "        min_lr_ratio,\n",
    "        clip,\n",
    "        **kwargs\n",
    "    ):\n",
    "        # Learning-rate schedule.\n",
    "        self.learning_rate, self.min_lr_ratio = learning_rate, min_lr_ratio\n",
    "        self.decay_method, self.lr_layer_decay_rate = decay_method, lr_layer_decay_rate\n",
    "        self.warmup_steps, self.train_steps = warmup_steps, train_steps\n",
    "        # Optimiser settings.\n",
    "        self.weight_decay, self.adam_epsilon, self.clip = weight_decay, adam_epsilon, clip\n",
    "        # Hardware layout.\n",
    "        self.use_tpu, self.num_core_per_host = use_tpu, num_core_per_host\n",
    "\n",
    "# Schedule values that end up in `training_parameters` below.\n",
    "num_train_steps = 300000\n",
    "warmup_proportion = 0.1\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "initial_learning_rate = 2e-5\n",
    "\n",
    "# This notebook only restores a checkpoint and exports it, so the graph is built\n",
    "# in inference mode. Building it with `is_training = True` would leave dropout\n",
    "# active in the exported graph and make its predictions non-deterministic.\n",
    "kwargs = dict(\n",
    "    is_training = False,\n",
    "    use_tpu = False,\n",
    "    use_bfloat16 = False,\n",
    "    dropout = 0.1,\n",
    "    dropatt = 0.1,\n",
    "    init = 'normal',\n",
    "    init_range = 0.1,\n",
    "    init_std = 0.05,\n",
    "    clamp_len = -1,\n",
    ")\n",
    "\n",
    "# Run-time options and architecture description handed to `xlnet.XLNetModel`.\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(\n",
    "    json_path = 'xlnet-base-29-03-2020/config.json'\n",
    ")\n",
    "\n",
    "# `Parameter` keeps the optimiser-related keys and discards the rest.\n",
    "training_parameters = dict(\n",
    "    decay_method = 'poly',\n",
    "    train_steps = num_train_steps,\n",
    "    learning_rate = initial_learning_rate,\n",
    "    warmup_steps = num_warmup_steps,\n",
    "    min_lr_ratio = 0.0,\n",
    "    weight_decay = 0.00,\n",
    "    adam_epsilon = 1e-8,\n",
    "    num_core_per_host = 1,\n",
    "    lr_layer_decay_rate = 1,\n",
    "    use_tpu = False,\n",
    "    use_bfloat16 = False,\n",
    "    dropout = 0.1,\n",
    "    dropatt = 0.1,\n",
    "    init = 'normal',\n",
    "    init_range = 0.1,\n",
    "    init_std = 0.05,\n",
    "    clip = 1.0,\n",
    "    clamp_len = -1,\n",
    ")\n",
    "training_parameters = Parameter(**training_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"Two-input XLNet classifier that shares one set of weights between both inputs.\n",
    "\n",
    "    Each input is run through `xlnet.XLNetModel` inside the `xlnet` variable scope\n",
    "    (the second time with `reuse = True`). The two pooled vectors and their\n",
    "    absolute difference are concatenated and fed to a dense layer that produces\n",
    "    `dimension_output` logits.\n",
    "\n",
    "    The placeholders are unnamed, so their graph names follow creation order;\n",
    "    the quantisation cell of this notebook refers to them as `Placeholder`,\n",
    "    `Placeholder_1`, ... Do not reorder them.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output = 2,\n",
    "    ):\n",
    "        # First input: token ids, segment ids and input mask (what `F` returns,\n",
    "        # wrapped in a list by the feeding cells).\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "        \n",
    "        # Second input, same layout as the first.\n",
    "        self.X_b = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids_b = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks_b = tf.placeholder(tf.float32, [None, None])\n",
    "        \n",
    "        # Integer class labels; only `cost` and `accuracy` read them.\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        # First branch: creates the XLNet variables. Inputs are transposed before\n",
    "        # being handed to XLNetModel.\n",
    "        with tf.compat.v1.variable_scope('xlnet', reuse = False):\n",
    "            xlnet_model = xlnet.XLNetModel(\n",
    "                xlnet_config=xlnet_config,\n",
    "                run_config=xlnet_parameters,\n",
    "                input_ids=tf.transpose(self.X, [1, 0]),\n",
    "                seg_ids=tf.transpose(self.segment_ids, [1, 0]),\n",
    "                input_mask=tf.transpose(self.input_masks, [1, 0]))\n",
    "\n",
    "            summary = xlnet_model.get_pooled_out(\"last\", True)\n",
    "            # Named inside the scope, so the tensor is exported as `xlnet/summary`.\n",
    "            summary = tf.identity(summary, name = 'summary')\n",
    "            self.summary = summary\n",
    "            print(summary)\n",
    "            \n",
    "        # Second branch: same scope with `reuse = True`, so no new variables are created.\n",
    "        with tf.compat.v1.variable_scope('xlnet', reuse = True):\n",
    "            xlnet_model = xlnet.XLNetModel(\n",
    "                xlnet_config=xlnet_config,\n",
    "                run_config=xlnet_parameters,\n",
    "                input_ids=tf.transpose(self.X_b, [1, 0]),\n",
    "                seg_ids=tf.transpose(self.segment_ids_b, [1, 0]),\n",
    "                input_mask=tf.transpose(self.input_masks_b, [1, 0]))\n",
    "            summary_b = xlnet_model.get_pooled_out(\"last\", True)\n",
    "        \n",
    "        # Pair features: both pooled vectors plus their element-wise absolute difference.\n",
    "        vectors_concat = [summary, summary_b, tf.abs(summary - summary_b)]\n",
    "        vectors_concat = tf.concat(vectors_concat, axis = 1)\n",
    "        \n",
    "        # The identity op gives the output the fixed name `logits` used when exporting.\n",
    "        self.logits = tf.layers.dense(vectors_concat, dimension_output)\n",
    "        self.logits = tf.identity(self.logits, name = 'logits')\n",
    "        \n",
    "        # Mean sparse softmax cross-entropy between `logits` and `Y`.\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        \n",
    "        # Fraction of rows whose arg-max logit equals the label.\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/xlnet/xlnet.py:220: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/xlnet.py:220: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:453: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:460: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:535: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:67: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "Tensor(\"xlnet/summary:0\", shape=(?, 768), dtype=float32)\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "\n",
    "# Start from an empty default graph so re-running this cell does not add a second model to the old one.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    ")\n",
    "\n",
    "# Give every variable an initial value; a later cell overwrites the trainable ones from the checkpoint.\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot of the trainable variables (not read again in the cells shown here).\n",
    "tvars = tf.trainable_variables()\n",
    "# Checkpoint prefix restored in the next cell.\n",
    "checkpoint = 'xlnet-base-keyphrase/model.ckpt-240000'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from xlnet-base-keyphrase/model.ckpt-240000\n"
     ]
    }
   ],
   "source": [
    "# Only the trainable variables are handed to the Saver, so only those are read from `checkpoint`.\n",
    "saver = tf.train.Saver(var_list = tf.trainable_variables())\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import cosine_similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-2.1187558 ,  0.38729835]], dtype=float32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = F('Kementerian Pertanian dan Industri Makanan menggalakkan pemain industri pertanian menceburi tanaman penting bagi mengurangkan kebergantungan bahan import dari luar negara')\n",
    "Y = F('tanaman jagung')\n",
    "\n",
    "# Pooled vector of each text on its own, both computed through the first input branch.\n",
    "o1, o2 = (\n",
    "    sess.run(\n",
    "        model.summary,\n",
    "        feed_dict = {\n",
    "            model.X: [ids],\n",
    "            model.segment_ids: [segments],\n",
    "            model.input_masks: [mask],\n",
    "        },\n",
    "    )\n",
    "    for ids, segments, mask in (X, Y)\n",
    ")\n",
    "\n",
    "# Logits for the (X, Y) pair; being the last expression, they are displayed.\n",
    "sess.run(\n",
    "    model.logits,\n",
    "    feed_dict = dict(\n",
    "        zip(\n",
    "            (\n",
    "                model.X,\n",
    "                model.segment_ids,\n",
    "                model.input_masks,\n",
    "                model.X_b,\n",
    "                model.segment_ids_b,\n",
    "                model.input_masks_b,\n",
    "            ),\n",
    "            ([part] for part in X + Y),\n",
    "        )\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.39292675]], dtype=float32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cosine similarity between the two pooled vectors computed in the previous cell.\n",
    "cosine_similarity(o1, o2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'output-xlnet-base-keyphrase/model.ckpt'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write the trainable variables under a new checkpoint prefix; `freeze_graph` is later pointed at this directory.\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'output-xlnet-base-keyphrase/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'Placeholder_3',\n",
       " 'Placeholder_4',\n",
       " 'Placeholder_5',\n",
       " 'Placeholder_6',\n",
       " 'xlnet/model/transformer/r_w_bias',\n",
       " 'xlnet/model/transformer/r_r_bias',\n",
       " 'xlnet/model/transformer/word_embedding/lookup_table',\n",
       " 'xlnet/model/transformer/r_s_bias',\n",
       " 'xlnet/model/transformer/seg_embed',\n",
       " 'xlnet/model/transformer/layer_0/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_0/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_0/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_0/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_0/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_0/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_0/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_0/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_0/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_0/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_0/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_1/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_1/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_1/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_1/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_1/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_1/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_1/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_1/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_1/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_1/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_1/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_2/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_2/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_2/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_2/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_2/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_2/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_2/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_2/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_2/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_2/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_2/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_3/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_3/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_3/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_3/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_3/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_3/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_3/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_3/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_3/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_3/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_3/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_4/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_4/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_4/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_4/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_4/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_4/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_4/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_4/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_4/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_4/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_4/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_5/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_5/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_5/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_5/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_5/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_5/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_5/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_5/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_5/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_5/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_5/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_6/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_6/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_6/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_6/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_6/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_6/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_6/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_6/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_6/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_6/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_6/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_7/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_7/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_7/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_7/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_7/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_7/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_7/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_7/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_7/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_7/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_7/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_8/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_8/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_8/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_8/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_8/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_8/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_8/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_8/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_8/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_8/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_8/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_9/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_9/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_9/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_9/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_9/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_9/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_9/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_9/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_9/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_9/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_9/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_10/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_10/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_10/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_10/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_10/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_10/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_10/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_10/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_10/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_10/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_10/ff/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_11/rel_attn/q/kernel',\n",
       " 'xlnet/model/transformer/layer_11/rel_attn/k/kernel',\n",
       " 'xlnet/model/transformer/layer_11/rel_attn/v/kernel',\n",
       " 'xlnet/model/transformer/layer_11/rel_attn/r/kernel',\n",
       " 'xlnet/model/transformer/layer_11/rel_attn/o/kernel',\n",
       " 'xlnet/model/transformer/layer_11/rel_attn/LayerNorm/gamma',\n",
       " 'xlnet/model/transformer/layer_11/ff/layer_1/kernel',\n",
       " 'xlnet/model/transformer/layer_11/ff/layer_1/bias',\n",
       " 'xlnet/model/transformer/layer_11/ff/layer_2/kernel',\n",
       " 'xlnet/model/transformer/layer_11/ff/layer_2/bias',\n",
       " 'xlnet/model/transformer/layer_11/ff/LayerNorm/gamma',\n",
       " 'xlnet/model_1/sequnece_summary/strided_slice/stack',\n",
       " 'xlnet/model_1/sequnece_summary/strided_slice/stack_1',\n",
       " 'xlnet/model_1/sequnece_summary/strided_slice/stack_2',\n",
       " 'xlnet/model_1/sequnece_summary/strided_slice',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/Initializer/random_normal/shape',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/Initializer/random_normal/mean',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/Initializer/random_normal/stddev',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/Initializer/random_normal/RandomStandardNormal',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/Initializer/random_normal/mul',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/Initializer/random_normal',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel',\n",
       " 'xlnet/model/sequnece_summary/summary/kernel/read',\n",
       " 'xlnet/model/sequnece_summary/summary/bias/Initializer/zeros',\n",
       " 'xlnet/model/sequnece_summary/summary/bias',\n",
       " 'xlnet/model/sequnece_summary/summary/bias/read',\n",
       " 'xlnet/model_1/sequnece_summary/summary/MatMul',\n",
       " 'xlnet/model_1/sequnece_summary/summary/BiasAdd',\n",
       " 'xlnet/model_1/sequnece_summary/summary/Tanh',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/rate',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/Shape',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/random_uniform/min',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/random_uniform/max',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/random_uniform/RandomUniform',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/random_uniform/sub',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/random_uniform/mul',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/random_uniform',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/sub/x',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/sub',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/truediv/x',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/truediv',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/GreaterEqual',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/mul',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/Cast',\n",
       " 'xlnet/model_1/sequnece_summary/dropout/dropout/mul_1',\n",
       " 'xlnet/summary',\n",
       " 'xlnet_1/model_1/sequnece_summary/strided_slice/stack',\n",
       " 'xlnet_1/model_1/sequnece_summary/strided_slice/stack_1',\n",
       " 'xlnet_1/model_1/sequnece_summary/strided_slice/stack_2',\n",
       " 'xlnet_1/model_1/sequnece_summary/strided_slice',\n",
       " 'xlnet_1/model_1/sequnece_summary/summary/MatMul',\n",
       " 'xlnet_1/model_1/sequnece_summary/summary/BiasAdd',\n",
       " 'xlnet_1/model_1/sequnece_summary/summary/Tanh',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/rate',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/Shape',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/random_uniform/min',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/random_uniform/max',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/random_uniform/RandomUniform',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/random_uniform/sub',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/random_uniform/mul',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/random_uniform',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/sub/x',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/sub',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/truediv/x',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/truediv',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/GreaterEqual',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/mul',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/Cast',\n",
       " 'xlnet_1/model_1/sequnece_summary/dropout/dropout/mul_1',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'logits']"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Comma-separated names of the graph nodes to keep when freezing: variable ops and\n",
    "# nodes whose name contains one of the wanted markers, minus the excluded markers.\n",
    "strings = ','.join(\n",
    "    n.name\n",
    "    for n in tf.get_default_graph().as_graph_def().node\n",
    "    if (\n",
    "        'Variable' in n.op\n",
    "        or any(\n",
    "            wanted in n.name\n",
    "            for wanted in ('Placeholder', 'logits', 'alphas', 'summary', 'self/Softmax')\n",
    "        )\n",
    "    )\n",
    "    and not any(\n",
    "        unwanted in n.name\n",
    "        for unwanted in ('Adam', 'beta', 'global_step', 'Identity', 'Assign')\n",
    "    )\n",
    ")\n",
    "# Displayed as a list so the selection can be checked by eye.\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"Freeze the checkpoint recorded in `model_dir` into `frozen_model.pb`.\n",
    "\n",
    "    The checkpoint's meta graph is imported into a fresh graph, its variables are\n",
    "    restored and converted to constants, and the sub-graph needed for the\n",
    "    requested nodes is written next to the checkpoint files.\n",
    "\n",
    "    Args:\n",
    "        model_dir: directory holding the checkpoint state file.\n",
    "        output_node_names: comma-separated names of the nodes to keep.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `model_dir` does not exist.\n",
    "        ValueError: if `model_dir` holds no checkpoint state.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # `get_checkpoint_state` returns None when the directory has no usable state file.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise ValueError('No checkpoint found in directory: %s' % model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # dirname/join keep the output beside the checkpoint even when the recorded\n",
    "    # path is a bare file name (no leading '/' is produced in that case).\n",
    "    output_graph = os.path.join(os.path.dirname(input_checkpoint), 'frozen_model.pb')\n",
    "\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = True\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from output-xlnet-base-keyphrase/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-20-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 165 variables.\n",
      "INFO:tensorflow:Converted 165 variables to const ops.\n",
      "15016 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "# Freeze the checkpoint saved above, keeping the nodes named in `strings`.\n",
    "freeze_graph('output-xlnet-base-keyphrase', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"Parse a serialized `GraphDef` file and import it into a new `tf.Graph`.\n",
    "\n",
    "    Returns the new graph; the default graph of the caller is left untouched.\n",
    "    \"\"\"\n",
    "    graph_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
    "        serialized = f.read()\n",
    "    graph_def.ParseFromString(serialized)\n",
    "\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the frozen graph back from disk (`g` is not used in the cells shown here).\n",
    "g = load_graph('output-xlnet-base-keyphrase/frozen_model.pb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph-transform steps passed, in this order, to `TransformGraph`.\n",
    "transforms = [\n",
    "    'add_default_attributes',\n",
    "    'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "    'fold_batch_norms',\n",
    "    'fold_old_batch_norms',\n",
    "    'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "    'strip_unused_nodes',\n",
    "    'sort_by_execution_order',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.tools.graph_transforms import TransformGraph\n",
    "# Fix the graph-level random seed of the current default graph.\n",
    "tf.set_random_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "pb = 'output-xlnet-base-keyphrase/frozen_model.pb'\n",
    "\n",
    "# Read the frozen graph. `tf.gfile.FastGFile` is deprecated in favour of\n",
    "# `tf.gfile.GFile`, which the other cells of this notebook already use.\n",
    "input_graph_def = tf.GraphDef()\n",
    "with tf.gfile.GFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "# Auto-generated names of the six input placeholders created by `Model`\n",
    "# (the label placeholder `Placeholder_6` is not an inference input).\n",
    "inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2', 'Placeholder_3',\n",
    " 'Placeholder_4',\n",
    " 'Placeholder_5',]\n",
    "outputs = ['xlnet/summary', 'logits']\n",
    "\n",
    "transformed_graph_def = TransformGraph(input_graph_def, \n",
    "                                           inputs,\n",
    "                                           outputs, transforms)\n",
    "\n",
    "# The transformed graph is written next to the frozen one, with a `.quantized` suffix.\n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
